import copy
import pickle
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

from .once_toolkits import Octopus
from ..dataset import DatasetTemplate
from ...ops.roiaware_pool3d import roiaware_pool3d_utils
from ...utils import box_utils


class ONCEDataset(DatasetTemplate):
    """ONCE LiDAR/camera dataset built on ``DatasetTemplate``.

    Frame-level info dicts are loaded from the pickle files listed in
    ``dataset_cfg.INFO_PATH[split]``. Raw sensor data is read through the
    ``Octopus`` toolkit rooted at ``self.root_path``.
    """

    def __init__(self, dataset_cfg, class_names, training=True, root_path=None, logger=None):
        """
        Args:
            dataset_cfg: dataset config. ``DATA_SPLIT`` selects the split and
                ``INFO_PATH`` lists the info pickles for each split.
            class_names: class names forwarded to ``DatasetTemplate``.
            training: if True use ``DATA_SPLIT["train"]``, else ``DATA_SPLIT["test"]``.
            root_path: dataset root forwarded to ``DatasetTemplate``.
            logger: optional logger.
        """
        super().__init__(
            dataset_cfg=dataset_cfg,
            class_names=class_names,
            training=training,
            root_path=root_path,
            logger=logger,
        )
        self.split = (
            dataset_cfg.DATA_SPLIT["train"] if training else dataset_cfg.DATA_SPLIT["test"]
        )
        assert self.split in ["train", "val", "test", "raw_small", "raw_medium", "raw_large"]

        split_dir = self.root_path / "ImageSets" / (self.split + ".txt")
        self.sample_seq_list = self._read_split_file(split_dir)
        self.cam_names = ["cam01", "cam03", "cam05", "cam06", "cam07", "cam08", "cam09"]
        self.cam_tags = [
            "top",
            "top2",
            "left_back",
            "left_front",
            "right_front",
            "right_back",
            "back",
        ]
        self.toolkits = Octopus(self.root_path)

        self.once_infos = []
        self.include_once_data(self.split)

    @staticmethod
    def _read_split_file(split_file):
        """Return the stripped, non-empty lines of ``split_file``, or None if it is missing."""
        if not split_file.exists():
            return None
        # ``with`` closes the handle. Blank lines (e.g. a trailing newline run)
        # would otherwise become empty sequence ids.
        with open(split_file) as f:
            return [line.strip() for line in f if line.strip()]

    def include_once_data(self, split):
        """Load frame infos for ``split`` and append them to ``self.once_infos``.

        Info files from ``dataset_cfg.INFO_PATH[split]`` that do not exist are
        skipped. For splits not starting with "raw", frames without an "annos"
        entry are dropped.
        """
        if self.logger is not None:
            self.logger.info("Loading ONCE dataset")
        once_infos = []

        for info_path in self.dataset_cfg.INFO_PATH[split]:
            info_path = self.root_path / info_path
            if not info_path.exists():
                continue
            # NOTE: pickle.load can execute arbitrary code; only load trusted info files.
            with open(info_path, "rb") as f:
                once_infos.extend(pickle.load(f))

        # Raw splits ("raw_small", "raw_medium", "raw_large") carry no labels,
        # so filtering them on "annos" would leave nothing.
        if not split.startswith("raw"):
            once_infos = [info for info in once_infos if "annos" in info]

        self.once_infos.extend(once_infos)

        if self.logger is not None:
            self.logger.info("Total samples for ONCE dataset: %d" % (len(once_infos)))

    def set_split(self, split):
        """Re-initialise the base dataset and switch to ``split``.

        Only ``self.split`` and ``self.sample_seq_list`` are updated.
        ``self.once_infos`` is not reloaded.
        """
        super().__init__(
            dataset_cfg=self.dataset_cfg,
            class_names=self.class_names,
            training=self.training,
            root_path=self.root_path,
            logger=self.logger,
        )
        self.split = split

        split_dir = self.root_path / "ImageSets" / (self.split + ".txt")
        self.sample_seq_list = self._read_split_file(split_dir)

    def get_lidar(self, sequence_id, frame_id):
        """Load the point cloud of a frame via the Octopus toolkit."""
        return self.toolkits.load_point_cloud(sequence_id, frame_id)

    def get_image(self, sequence_id, frame_id, cam_name):
        """Load one camera image of a frame via the Octopus toolkit."""
        return self.toolkits.load_image(sequence_id, frame_id, cam_name)

    def project_lidar_to_image(self, sequence_id, frame_id):
        """Delegate LiDAR-to-image projection to the Octopus toolkit."""
        return self.toolkits.project_lidar_to_image(sequence_id, frame_id)

    def point_painting(self, points, info):
        """Append per-point class scores sampled from 2D segmentation label maps.

        For each camera in ``self.cam_names`` the label image
        ``<SEMSEG_DIR>/<sequence_id>/<cam_name>/<frame_id>_label.png`` is read.
        It is one-hot encoded over ``used_classes`` and bilinearly sampled at each
        point's projection. Only points with positive camera depth that project
        inside the image are written. When several cameras see a point, the last
        one in ``self.cam_names`` wins.

        Args:
            points: (N, >=3) array whose first three columns are x, y, z.
            info: frame info with "frame_id", "sequence_id" and
                ``info["calib"][cam_name]["cam_to_velo" / "cam_intrinsic"]``.

        Returns:
            ``points`` concatenated column-wise with an (N, num_classes) score
            array. Rows for points seen by no camera are zero.
        """
        # Configurable segmentation root; "./" keeps the previous default.
        semseg_dir = self.dataset_cfg.get("SEMSEG_DIR", "./")
        used_classes = [0, 1, 2, 3, 4, 5]
        num_classes = len(used_classes)
        frame_id = str(info["frame_id"])
        seq_id = str(info["sequence_id"])
        painted = np.zeros((points.shape[0], num_classes))  # classes + bg

        # Homogeneous xyz does not depend on the camera, so build it once.
        point_xyz = points[:, :3]
        points_homo = np.hstack(
            [point_xyz, np.ones(point_xyz.shape[0], dtype=np.float32).reshape((-1, 1))]
        )

        for cam_name in self.cam_names:
            img_path = Path(semseg_dir) / seq_id / cam_name / (frame_id + "_label.png")
            calib_info = info["calib"][cam_name]
            cam_2_velo = calib_info["cam_to_velo"]
            cam_intri = np.hstack(
                [calib_info["cam_intrinsic"], np.zeros((3, 1), dtype=np.float32)]
            )

            # LiDAR -> camera frame via inv(cam_to_velo). Keep positive-depth points.
            points_cam = np.dot(points_homo, np.linalg.inv(cam_2_velo).T)
            in_front = points_cam[:, 2] > 0
            points_img = np.dot(points_cam[in_front], cam_intri.T)
            points_img = points_img / points_img[:, [2]]
            uv = points_img[:, [0, 1]]  # fancy indexing -> independent copy

            seg_map = np.array(Image.open(img_path))  # (H, W)
            H, W = seg_map.shape

            # Pixel coordinates -> grid_sample's normalised [-1, 1] range.
            uv[:, 0] = (uv[:, 0] - W / 2) / (W / 2)
            uv[:, 1] = (uv[:, 1] - H / 2) / (H / 2)
            # Points outside this image would sample zeros (padding_mode="zeros")
            # and erase scores painted by an earlier camera, so drop them here.
            in_image = (np.abs(uv[:, 0]) <= 1) & (np.abs(uv[:, 1]) <= 1)
            if not in_image.any():
                continue

            seg_feats = np.zeros((H * W, num_classes))
            flat_seg = seg_map.reshape(-1)
            for cls_i in used_classes:
                seg_feats[flat_seg == cls_i, cls_i] = 1
            seg_feats = seg_feats.reshape(H, W, num_classes).transpose(2, 0, 1)

            uv_tensor = torch.from_numpy(uv[in_image]).unsqueeze(0).unsqueeze(0)  # [1,1,M,2]
            seg_tensor = torch.from_numpy(seg_feats).unsqueeze(0)  # [1,C,H,W]
            proj_scores = F.grid_sample(
                seg_tensor, uv_tensor, mode="bilinear", padding_mode="zeros"
            )  # [1, C, 1, M]
            proj_scores = proj_scores.squeeze(0).squeeze(1).transpose(0, 1).contiguous()  # [M, C]
            # Map the in-front & in-image subset back to row indices of ``points``.
            painted[np.flatnonzero(in_front)[in_image]] = proj_scores.numpy()
        return np.concatenate([points, painted], axis=1)

    def __len__(self):
        """Number of samples, multiplied by ``total_epochs`` when all epochs are merged."""
        if self._merge_all_iters_to_one_epoch:
            return len(self.once_infos) * self.total_epochs

        return len(self.once_infos)

    def __getitem__(self, index):
        """Load a frame's points (optionally painted) and GT, then run ``prepare_data``."""
        if self._merge_all_iters_to_one_epoch:
            index = index % len(self.once_infos)

        info = copy.deepcopy(self.once_infos[index])
        frame_id = info["frame_id"]
        seq_id = info["sequence_id"]
        points = self.get_lidar(seq_id, frame_id)

        if self.dataset_cfg.get("POINT_PAINTING", False):
            points = self.point_painting(points, info)

        input_dict = {
            "points": points,
            "frame_id": frame_id,
        }

        if "annos" in info:
            annos = info["annos"]
            input_dict.update(
                {
                    "gt_names": annos["name"],
                    "gt_boxes": annos["boxes_3d"],
                    "num_points_in_gt": annos.get("num_points_in_gt", None),
                }
            )

        data_dict = self.prepare_data(data_dict=input_dict)
        data_dict.pop("num_points_in_gt", None)
        return data_dict

    def get_infos(self, num_workers=4, sample_seq_list=None):
        """Convert per-sequence ONCE json files into pcdet-format frame infos.

        Dataset json format::

            {
                'meta_info':
                'calib': {
                    'cam01': {
                        'cam_to_velo': list
                        'cam_intrinsic': list
                        'distortion': list
                    }
                    ...
                }
                'frames': [
                    {
                        'frame_id': timestamp,
                        'annos': {
                            'names': list
                            'boxes_3d': list of list
                            'boxes_2d': {
                                'cam01': list of list
                                ...
                            }
                        }
                        'pose': list
                    },
                    ...
                ]
            }

        Produced open-pcdet format (one dict per frame)::

            {
                'meta_info':
                'sequence_id': seq_idx
                'frame_id': timestamp
                'timestamp': timestamp
                'prev_id' / 'next_id': neighbouring frame ids or None
                'lidar': path
                'cam01': path
                ...
                'calib': {
                    'cam01': {
                        'cam_to_velo': np.array
                        'cam_intrinsic': np.array
                        'distortion': np.array
                    }
                    ...
                }
                'pose': np.array
                'annos': {
                    'name': np.array
                    'boxes_3d': np.array
                    'boxes_2d': {
                        'cam01': np.array
                        ....
                    }
                    'num_points_in_gt': np.array
                }
            }

        Frames whose annotation has zero 3D boxes are skipped entirely.

        Args:
            num_workers: thread-pool size used to process sequences.
            sample_seq_list: sequence ids to process; defaults to ``self.sample_seq_list``.

        Returns:
            A flat list of frame info dicts, in sequence order.
        """
        import concurrent.futures as futures
        import json

        root_path = self.root_path
        cam_names = self.cam_names

        def process_single_sequence(seq_idx):
            print("%s seq_idx: %s" % (self.split, seq_idx))
            seq_infos = []
            seq_path = Path(root_path) / "data" / seq_idx
            json_path = seq_path / ("%s.json" % seq_idx)
            with open(json_path, "r") as f:
                info_this_seq = json.load(f)
            meta_info = info_this_seq["meta_info"]
            calib = info_this_seq["calib"]
            frames = info_this_seq["frames"]
            last_idx = len(frames) - 1
            for f_idx, frame in enumerate(frames):
                frame_id = frame["frame_id"]
                prev_id = frames[f_idx - 1]["frame_id"] if f_idx > 0 else None
                next_id = frames[f_idx + 1]["frame_id"] if f_idx < last_idx else None
                pc_path = str(seq_path / "lidar_roof" / ("%s.bin" % frame_id))
                frame_dict = {
                    "sequence_id": seq_idx,
                    "frame_id": frame_id,
                    "timestamp": int(frame_id),
                    "prev_id": prev_id,
                    "next_id": next_id,
                    "meta_info": meta_info,
                    "lidar": pc_path,
                    "pose": np.array(frame["pose"]),
                }
                calib_dict = {}
                for cam_name in cam_names:
                    frame_dict[cam_name] = str(seq_path / cam_name / ("%s.jpg" % frame_id))
                    cam_calib = calib[cam_name]
                    calib_dict[cam_name] = {
                        "cam_to_velo": np.array(cam_calib["cam_to_velo"]),
                        "cam_intrinsic": np.array(cam_calib["cam_intrinsic"]),
                        "distortion": np.array(cam_calib["distortion"]),
                    }
                frame_dict["calib"] = calib_dict

                if "annos" in frame:
                    annos = frame["annos"]
                    boxes_3d = np.array(annos["boxes_3d"])
                    if boxes_3d.shape[0] == 0:
                        print(frame_id)
                        continue
                    annos_dict = {
                        "name": np.array(annos["names"]),
                        "boxes_3d": boxes_3d,
                        "boxes_2d": {
                            cam_name: np.array(annos["boxes_2d"][cam_name])
                            for cam_name in cam_names
                        },
                    }

                    # Count LiDAR points falling inside each box's corner hull.
                    points = self.get_lidar(seq_idx, frame_id)
                    corners_lidar = box_utils.boxes_to_corners_3d(boxes_3d)
                    num_gt = boxes_3d.shape[0]
                    num_points_in_gt = -np.ones(num_gt, dtype=np.int32)
                    for k in range(num_gt):
                        flag = box_utils.in_hull(points[:, 0:3], corners_lidar[k])
                        num_points_in_gt[k] = flag.sum()
                    annos_dict["num_points_in_gt"] = num_points_in_gt

                    frame_dict["annos"] = annos_dict
                seq_infos.append(frame_dict)
            return seq_infos

        sample_seq_list = sample_seq_list if sample_seq_list is not None else self.sample_seq_list
        with futures.ThreadPoolExecutor(num_workers) as executor:
            infos = list(executor.map(process_single_sequence, sample_seq_list))
        all_infos = []
        for seq_infos in infos:
            all_infos.extend(seq_infos)
        return all_infos

    def create_groundtruth_database(self, info_path=None, used_classes=None, split="train"):
        """Crop the points inside every GT box into per-object ``.bin`` files.

        Points are shifted so the box centre is the origin and written under
        ``root_path/gt_database`` (or ``gt_database_<split>``). A dict
        ``{class_name: [db_info, ...]}`` is pickled to
        ``root_path/once_dbinfos_<split>.pkl``.

        Args:
            info_path: path to an info pickle produced by ``get_infos``.
            used_classes: optional collection of class names to keep; None keeps all.
            split: split name used to build the output paths.
        """
        database_save_path = Path(self.root_path) / (
            "gt_database" if split == "train" else ("gt_database_%s" % split)
        )
        db_info_save_path = Path(self.root_path) / ("once_dbinfos_%s.pkl" % split)

        database_save_path.mkdir(parents=True, exist_ok=True)
        all_db_infos = {}

        with open(info_path, "rb") as f:
            infos = pickle.load(f)

        for k, info in enumerate(infos):
            if "annos" not in info:
                continue
            print("gt_database sample: %d" % (k + 1))
            frame_id = info["frame_id"]
            seq_id = info["sequence_id"]
            points = self.get_lidar(seq_id, frame_id)
            annos = info["annos"]
            names = annos["name"]
            gt_boxes = annos["boxes_3d"]

            num_obj = gt_boxes.shape[0]
            point_indices = roiaware_pool3d_utils.points_in_boxes_cpu(
                torch.from_numpy(points[:, 0:3]), torch.from_numpy(gt_boxes)
            ).numpy()  # (nboxes, npoints)

            for i in range(num_obj):
                if used_classes is not None and names[i] not in used_classes:
                    continue
                filename = "%s_%s_%d.bin" % (frame_id, names[i], i)
                filepath = database_save_path / filename
                gt_points = points[point_indices[i] > 0]  # boolean mask -> copy

                gt_points[:, :3] -= gt_boxes[i, :3]
                # Raw array bytes: open in binary mode.
                with open(filepath, "wb") as f:
                    gt_points.tofile(f)

                db_path = str(filepath.relative_to(self.root_path))  # gt_database/xxxxx.bin
                db_info = {
                    "name": names[i],
                    "path": db_path,
                    "gt_idx": i,
                    "box3d_lidar": gt_boxes[i],
                    "num_points_in_gt": gt_points.shape[0],
                }
                all_db_infos.setdefault(names[i], []).append(db_info)

        for k, v in all_db_infos.items():
            print("Database %s: %d" % (k, len(v)))

        with open(db_info_save_path, "wb") as f:
            pickle.dump(all_db_infos, f)

    @staticmethod
    def generate_prediction_dicts(batch_dict, pred_dicts, class_names, output_path=None):
        """Convert model outputs into per-frame dicts with "name", "score", "boxes_3d", "frame_id".

        ``pred_labels`` are treated as 1-based indices into ``class_names``.
        Writing to ``output_path`` is not implemented and raises NotImplementedError.
        """

        def get_template_prediction(num_samples):
            ret_dict = {
                "name": np.zeros(num_samples),
                "score": np.zeros(num_samples),
                "boxes_3d": np.zeros((num_samples, 7)),
            }
            return ret_dict

        def generate_single_sample_dict(box_dict):
            pred_scores = box_dict["pred_scores"].cpu().numpy()
            pred_boxes = box_dict["pred_boxes"].cpu().numpy()
            pred_labels = box_dict["pred_labels"].cpu().numpy()
            pred_dict = get_template_prediction(pred_scores.shape[0])
            if pred_scores.shape[0] == 0:
                return pred_dict

            pred_dict["name"] = np.array(class_names)[pred_labels - 1]
            pred_dict["score"] = pred_scores
            pred_dict["boxes_3d"] = pred_boxes
            return pred_dict

        annos = []
        for index, box_dict in enumerate(pred_dicts):
            frame_id = batch_dict["frame_id"][index]
            single_pred_dict = generate_single_sample_dict(box_dict)
            single_pred_dict["frame_id"] = frame_id
            annos.append(single_pred_dict)

            if output_path is not None:
                raise NotImplementedError
        return annos

    def evaluation(self, det_annos, class_names, **kwargs):
        """Evaluate ``det_annos`` against the "annos" of every loaded info.

        Returns:
            ``(ap_result_str, ap_dict)`` as produced by ``get_evaluation_results``.
        """
        from .once_eval.evaluation import get_evaluation_results

        eval_det_annos = copy.deepcopy(det_annos)
        eval_gt_annos = [copy.deepcopy(info["annos"]) for info in self.once_infos]
        ap_result_str, ap_dict = get_evaluation_results(eval_gt_annos, eval_det_annos, class_names)

        return ap_result_str, ap_dict


def create_once_infos(dataset_cfg, class_names, data_path, save_path, workers=4):
    """Generate ``once_infos_<split>.pkl`` for each split and the train GT database.

    The "test" split is skipped. The GT database is built from
    ``save_path/once_infos_train.pkl`` and written under the dataset root
    (``data_path``).

    Args:
        dataset_cfg: dataset config passed to ``ONCEDataset``.
        class_names: class names passed to ``ONCEDataset``.
        data_path: dataset root directory.
        save_path: directory for the info pickles (``str`` or ``Path``).
        workers: number of worker threads used by ``get_infos``.
    """
    # Normalise once so every ``/`` join below works for str inputs too.
    save_path = Path(save_path)
    dataset = ONCEDataset(
        dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path, training=False
    )

    splits = ["train", "val", "test", "raw_small", "raw_medium", "raw_large"]
    ignore = ["test"]

    print("---------------Start to generate data infos---------------")
    for split in splits:
        if split in ignore:
            continue

        filename = save_path / ("once_infos_%s.pkl" % split)
        dataset.set_split(split)
        once_infos = dataset.get_infos(num_workers=workers)
        with open(filename, "wb") as f:
            pickle.dump(once_infos, f)
        print("ONCE info %s file is saved to %s" % (split, filename))

    train_filename = save_path / "once_infos_train.pkl"
    print("---------------Start create groundtruth database for data augmentation---------------")
    dataset.set_split("train")
    dataset.create_groundtruth_database(train_filename, split="train")
    print("---------------Data preparation Done---------------")
    print("---------------Data preparation Done---------------")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="arg parser")
    parser.add_argument("--cfg_file", type=str, default=None, help="specify the config of dataset")
    # NOTE(review): this default matches no branch below; pass
    # --func create_once_infos explicitly to generate ONCE infos.
    parser.add_argument("--func", type=str, default="create_waymo_infos", help="")
    parser.add_argument("--runs_on", type=str, default="server", help="")
    args = parser.parse_args()

    if args.func == "create_once_infos":
        import yaml
        from easydict import EasyDict

        # yaml.load() without an explicit Loader is rejected by PyYAML >= 6;
        # safe_load handles plain config files and closes the handle via ``with``.
        with open(args.cfg_file) as cfg_f:
            dataset_cfg = EasyDict(yaml.safe_load(cfg_f))

        ROOT_DIR = (Path(__file__).resolve().parent / "../../../").resolve()
        once_data_path = ROOT_DIR / "data" / "once"
        once_save_path = ROOT_DIR / "data" / "once"

        if args.runs_on == "cloud":
            once_data_path = Path("/cache/once/")
            once_save_path = Path("/cache/once/")
            dataset_cfg.DATA_PATH = dataset_cfg.CLOUD_DATA_PATH

        create_once_infos(
            dataset_cfg=dataset_cfg,
            class_names=["Car", "Bus", "Truck", "Pedestrian", "Bicycle"],
            data_path=once_data_path,
            save_path=once_save_path,
        )
